import os
import csv
import glob
import torch
import numpy as np
import collections
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import random
# Pick the first CUDA GPU when CUDA is available, otherwise fall back to the CPU.
# When a GPU is chosen it is also made the current CUDA device so that
# un-indexed 'cuda' tensors land on it.
_cuda_available = torch.cuda.is_available()
if _cuda_available:
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)
else:
    device = torch.device('cpu')


def get_base_class_images(root_dir, n_base, experts=1, image_size=224, overlap=0,
                          n_images=100):
    """
    Sample random base classes from the training split and load their images.

    Reads ``<root_dir>/train.csv`` (a header row followed by
    ``image_name,class_name`` rows). It then draws ``n_base * experts``
    distinct classes at random with ``np.random.choice`` and loads the first
    ``n_images`` images listed for each selected class from
    ``<root_dir>/images``.

    Args:
        root_dir (str): Path to the dataset root directory.
        n_base (int): Number of classes to select per expert.
        experts (int): Multiplier on ``n_base``; ``n_base * experts`` classes
            are selected in total (default: 1).
        image_size (int): Images are resized so the shorter side equals this
            value, then center-cropped to ``image_size x image_size``
            (default: 224).
        overlap (int): Currently unused; accepted for interface compatibility.
        n_images (int): Maximum number of images loaded per class, taken in
            CSV order (default: 100).

    Returns:
        tuple: (class_images, class_labels) where:
            class_images: torch.Tensor of shape
                [n_base * experts, n_per_class, 3, image_size, image_size],
                normalized with ImageNet mean/std. ``n_per_class`` is
                ``min(n_images, images listed for the class)``.
            class_labels: list of str class names, aligned with dim 0 of
                ``class_images``.

    Raises:
        FileNotFoundError: If ``train.csv`` or a listed image file is missing.
        ValueError: If fewer than ``n_base * experts`` classes are available.
        RuntimeError: From ``torch.stack`` if the selected classes end up with
            different numbers of images (i.e. some class lists fewer than
            ``n_images`` images while others do not).
    """
    csv_path = os.path.join(root_dir, 'train.csv')

    # Map each class name to the image paths listed for it, in CSV order.
    class_dict = collections.defaultdict(list)
    # newline='' is what the csv module expects for correct row parsing.
    with open(csv_path, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header; tolerate a completely empty file
        for row in reader:
            if not row:  # csv.reader yields [] for blank lines
                continue
            img_name, class_name = row
            class_dict[class_name].append(os.path.join(root_dir, 'images', img_name))

    class_list = list(class_dict.keys())
    n_classes = n_base * experts
    if n_classes > len(class_list):
        raise ValueError(
            f"requested {n_classes} classes (n_base={n_base} x experts={experts}) "
            f"but {csv_path} lists only {len(class_list)}"
        )
    selected_classes = np.random.choice(class_list, n_classes, replace=False)

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    class_images = []
    class_labels = []
    for class_name in selected_classes:
        class_imgs = []
        for img_path in class_dict[class_name][:n_images]:
            # The context manager closes the file handle. convert() returns a
            # new, fully loaded image, so it stays usable after the close.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
            class_imgs.append(transform(img))
        class_images.append(torch.stack(class_imgs))
        # np.random.choice yields np.str_; return plain str labels.
        class_labels.append(str(class_name))

    # Requires every selected class to have the same number of images.
    class_images = torch.stack(class_images)

    return class_images, class_labels


# Example usage: build few-shot dataloaders and print the tensor shapes of the
# first training batch, then sample base-class images and print their shapes.
if __name__ == '__main__':
    # Path to your MiniImagenet dataset
    root_dir = '/root/mini-imagenet'
    
    # Get dataloaders
    # NOTE(review): `get_miniimagenet_dataloaders` is neither defined nor
    # imported in this file as shown. Unless it is provided elsewhere, running
    # this script raises NameError here (confirm).
    train_loader, val_loader, test_loader = get_miniimagenet_dataloaders(
        root_dir=root_dir,
        batch_size=32,
        n_way=5,
        k_shot=1,
        n_query=5
    )
    
    # Example of iterating through the dataloader
    # (assumes each batch unpacks to support/query tensors plus their labels;
    # this depends on the loader factory above, so confirm).
    for batch_idx, (support_set, support_labels, query_set, query_labels) in enumerate(train_loader):
        print(f"Batch {batch_idx}:")
        print(f"Support set shape: {support_set.shape}")
        print(f"Support labels shape: {support_labels.shape}")
        print(f"Query set shape: {query_set.shape}")
        print(f"Query labels shape: {query_labels.shape}")
        
        # Break after first batch for demonstration
        if batch_idx == 0:
            break

    # Get base class images: n_base random classes (experts defaults to 1),
    # with up to 100 images per class stacked into one tensor.
    n_base = 10
    class_images, class_labels = get_base_class_images(root_dir, n_base)
    
    print(f"Base class images shape: {class_images.shape}")
    print(f"Number of classes: {len(class_labels)}")
    print(f"Class labels: {class_labels}")
    # Dim 1 of the stacked tensor is the per-class image count.
    print(f"Number of images per class: {class_images.shape[1]}")